{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "\n",
    "import pickle\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from tianshou.data import Batch, ReplayBuffer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xoPiGVD8LNma"
   },
   "source": [
    "# Buffer\n",
    "A replay buffer is a very common module in deep reinforcement learning (DRL) implementations. In Tianshou, the Buffer module can be viewed as a specialized form of Batch that keeps track of all data trajectories and offers utilities such as sampling methods on top of basic storage.\n",
    "\n",
    "Tianshou provides many kinds of Buffer modules; the two most basic ones are ReplayBuffer and VectorReplayBuffer. The latter is specially designed for parallelized environments (introduced in the tutorial [Vectorized Environment](https://tianshou.readthedocs.io/en/master/02_notebooks/L3_Vectorized__Environment.html)). This tutorial focuses on ReplayBuffer."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "OdesCAxANehZ"
   },
   "source": [
    "## Usages"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "fUbLl9T_SrTR"
   },
   "source": [
    "### Basic usages as a batch\n",
    "Typically, a buffer stores all data in batches, employing a circular-queue mechanism."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "mocZ6IqZTH62",
    "outputId": "66cc4181-c51b-4a47-aacf-666b92b7fc52"
   },
   "outputs": [],
   "source": [
    "# a buffer is initialised with its maxsize set to 10 (older data will be discarded if more data flow in).\n",
    "print(\"========================================\")\n",
    "dummy_buf = ReplayBuffer(size=10)\n",
    "print(dummy_buf)\n",
    "print(f\"maxsize: {dummy_buf.maxsize}, data length: {len(dummy_buf)}\")\n",
    "\n",
    "# add 3 steps of data into ReplayBuffer sequentially\n",
    "print(\"========================================\")\n",
    "for i in range(3):\n",
    "    dummy_buf.add(\n",
    "        Batch(obs=i, act=i, rew=i, terminated=0, truncated=0, done=0, obs_next=i + 1, info={}),\n",
    "    )\n",
    "print(dummy_buf)\n",
    "print(f\"maxsize: {dummy_buf.maxsize}, data length: {len(dummy_buf)}\")\n",
    "\n",
    "# add another 10 steps of data into ReplayBuffer sequentially\n",
    "print(\"========================================\")\n",
    "for i in range(3, 13):\n",
    "    dummy_buf.add(\n",
    "        Batch(obs=i, act=i, rew=i, terminated=0, truncated=0, done=0, obs_next=i + 1, info={}),\n",
    "    )\n",
    "print(dummy_buf)\n",
    "print(f\"maxsize: {dummy_buf.maxsize}, data length: {len(dummy_buf)}\")"
   ]
  },
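  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The circular-queue behaviour seen in the output above can be illustrated with a minimal NumPy sketch. This is not Tianshou's implementation: `storage`, `write_index` and `length` are hypothetical names used only for illustration, and a single integer stands in for each stored step."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "maxsize = 10  # capacity, matching the buffer size used above\n",
    "storage = np.zeros(maxsize, dtype=int)  # hypothetical stand-in for the stored data\n",
    "write_index = 0  # slot that the next step will be written to\n",
    "length = 0  # number of valid entries, never larger than maxsize\n",
    "\n",
    "for step in range(13):\n",
    "    storage[write_index] = step  # overwrites the oldest entry once the buffer is full\n",
    "    write_index = (write_index + 1) % maxsize  # wrap around at the end\n",
    "    length = min(length + 1, maxsize)\n",
    "\n",
    "print(storage)\n",
    "print(f\"data length: {length}, next write index: {write_index}\")"
   ]
  },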
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "H8B85Y5yUfTy"
   },
   "source": [
    "Just like Batch, ReplayBuffer supports concatenation, splitting, advanced slicing and indexing, etc."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "cOX-ADOPNeEK",
    "outputId": "f1a8ec01-b878-419b-f180-bdce3dee73e6"
   },
   "outputs": [],
   "source": [
    "print(dummy_buf[-1])\n",
    "print(dummy_buf[-3:])\n",
    "# Try more methods you find useful in Batch yourself."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vqldap-2WQBh"
   },
   "source": [
    "A ReplayBuffer can also be saved to local disk while still keeping track of the trajectories. This is extremely helpful in offline DRL settings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Ppx0L3niNT5K"
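   },
   "outputs": [],
   "source": [
    "# Round-trip the buffer through pickle in memory; pickle.dump and pickle.load work the same way with a file.\n",
    "_dummy_buf = pickle.loads(pickle.dumps(dummy_buf))"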
   ]
  },
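  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To write the buffer to an actual file, the same `pickle` calls can target a file object. The sketch below is one possible way to do this, assuming the buffer pickles to a file just as it does in memory; the buffer `disk_buf`, the file name `buffer.pkl` and the temporary directory are hypothetical choices for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import tempfile\n",
    "from pathlib import Path\n",
    "\n",
    "from tianshou.data import Batch, ReplayBuffer\n",
    "\n",
    "disk_buf = ReplayBuffer(size=5)  # hypothetical buffer used only in this example\n",
    "for i in range(3):\n",
    "    disk_buf.add(\n",
    "        Batch(obs=i, act=i, rew=i, terminated=0, truncated=0, done=0, obs_next=i + 1, info={}),\n",
    "    )\n",
    "\n",
    "with tempfile.TemporaryDirectory() as tmp_dir:\n",
    "    path = Path(tmp_dir) / \"buffer.pkl\"  # hypothetical file name\n",
    "    with path.open(\"wb\") as f:\n",
    "        pickle.dump(disk_buf, f)  # save the buffer to disk\n",
    "    with path.open(\"rb\") as f:\n",
    "        restored_buf = pickle.load(f)  # load it back\n",
    "\n",
    "print(f\"maxsize: {restored_buf.maxsize}, data length: {len(restored_buf)}\")"
   ]
  },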
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Eqezp0OyXn6J"
   },
   "source": [
    "### Understanding reserved keys for buffer\n",
    "As explained above, ReplayBuffer is specially designed to support the implementation of DRL algorithms. For convenience, nine keys are reserved in Batch:\n",
    "\n",
    "*   `obs`\n",
    "*   `act`\n",
    "*   `rew`\n",
    "*   `terminated`\n",
    "*   `truncated`\n",
    "*   `done`\n",
    "*   `obs_next`\n",
    "*   `info`\n",
    "*   `policy`\n",
    "\n",
    "The meanings of these nine reserved keys are consistent with those in [Gymnasium](https://gymnasium.farama.org/index.html#). We recommend using only these nine keys when adding batched data to a ReplayBuffer, because\n",
    "some of them are tracked by the ReplayBuffer (e.g., the \"done\" value is tracked to help determine a trajectory's start and end indexes, together with its total reward and episode length).\n",
    "\n",
    "```\n",
    "buf.add(Batch(......, extra_info=0)) # This is okay but not recommended.\n",
    "buf.add(Batch(......, info={\"extra_info\":0})) # Recommended.\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ueAbTspsc6jo"
   },
   "source": [
    "### Data sampling\n",
    "The primary purpose of maintaining a replay buffer in DRL is to sample data for training. `ReplayBuffer.sample()` and `ReplayBuffer.split(..., shuffle=True)` can both fulfill this need."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "P5xnYOhrchDl",
    "outputId": "bcd2c970-efa6-43bb-8709-720d38f77bbd"
   },
   "outputs": [],
   "source": [
    "dummy_buf.sample(batch_size=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "IWyaOSKOcgK4"
   },
   "source": [
    "## Trajectory tracking\n",
    "Compared to Batch, a unique feature of ReplayBuffer is that it helps you track environment trajectories.\n",
    "\n",
    "First, let us simulate a situation where we add three trajectories to the buffer. The last trajectory is not finished yet."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "editable": true,
    "id": "H0qRb6HLfhLB",
    "outputId": "9bdb7d4e-b6ec-489f-a221-0bddf706d85b",
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "trajectory_buffer = ReplayBuffer(size=10)\n",
    "# Add the first trajectory (length is 3) into ReplayBuffer\n",
    "print(\"========================================\")\n",
    "for i in range(3):\n",
    "    result = trajectory_buffer.add(\n",
    "        Batch(\n",
    "            obs=i,\n",
    "            act=i,\n",
    "            rew=i,\n",
    "            terminated=1 if i == 2 else 0,\n",
    "            truncated=0,\n",
    "            done=i == 2,\n",
    "            obs_next=i + 1,\n",
    "            info={},\n",
    "        ),\n",
    "    )\n",
    "    print(result)\n",
    "print(trajectory_buffer)\n",
    "print(f\"maxsize: {trajectory_buffer.maxsize}, data length: {len(trajectory_buffer)}\")\n",
    "\n",
    "# Add the second trajectory (length is 5) into ReplayBuffer\n",
    "print(\"========================================\")\n",
    "for i in range(3, 8):\n",
    "    result = trajectory_buffer.add(\n",
    "        Batch(\n",
    "            obs=i,\n",
    "            act=i,\n",
    "            rew=i,\n",
    "            terminated=1 if i == 7 else 0,\n",
    "            truncated=0,\n",
    "            done=i == 7,\n",
    "            obs_next=i + 1,\n",
    "            info={},\n",
    "        ),\n",
    "    )\n",
    "    print(result)\n",
    "print(trajectory_buffer)\n",
    "print(f\"maxsize: {trajectory_buffer.maxsize}, data length: {len(trajectory_buffer)}\")\n",
    "\n",
    "# Add the third trajectory (length is 5, still not finished) into ReplayBuffer\n",
    "print(\"========================================\")\n",
    "for i in range(8, 13):\n",
    "    result = trajectory_buffer.add(\n",
    "        Batch(obs=i, act=i, rew=i, terminated=0, truncated=0, done=False, obs_next=i + 1, info={}),\n",
    "    )\n",
    "    print(result)\n",
    "print(trajectory_buffer)\n",
    "print(f\"maxsize: {trajectory_buffer.maxsize}, data length: {len(trajectory_buffer)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dO7PWdb_hkXA"
   },
   "source": [
    "### Episode length and rewards tracking\n",
    "Notice that `ReplayBuffer.add()` returns a tuple of four values on every call: `(current_index, episode_reward, episode_length, episode_start_index)`. `episode_reward` and `episode_length` are valid only when a trajectory has finished. This saves developers from tracking episode statistics by hand.\n",
    "\n"
   ]
  },
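  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch, the returned tuple can be unpacked directly. The variable names follow the description above and the buffer `episode_buf` is a hypothetical example; the episode statistics are printed only for the step that ends the episode."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tianshou.data import Batch, ReplayBuffer\n",
    "\n",
    "episode_buf = ReplayBuffer(size=10)  # hypothetical buffer used only in this example\n",
    "for i in range(3):\n",
    "    is_last = i == 2  # the episode ends at the third step\n",
    "    current_index, episode_reward, episode_length, episode_start_index = episode_buf.add(\n",
    "        Batch(\n",
    "            obs=i,\n",
    "            act=i,\n",
    "            rew=1.0,\n",
    "            terminated=is_last,\n",
    "            truncated=False,\n",
    "            done=is_last,\n",
    "            obs_next=i + 1,\n",
    "            info={},\n",
    "        ),\n",
    "    )\n",
    "    if is_last:\n",
    "        # episode_reward and episode_length are only meaningful once the episode has finished\n",
    "        print(episode_reward, episode_length, episode_start_index)"
   ]
  },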
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xbVc90z8itH0"
   },
   "source": [
    "### Episode index management\n",
    "In the ReplayBuffer above, we can access any data step by indexing.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "4mKwo54MjupY",
    "outputId": "9ae14a7e-908b-44eb-afec-89b45bac5961"
   },
   "outputs": [],
   "source": [
    "print(trajectory_buffer)\n",
    "print(\"========================================\")\n",
    "\n",
    "data = trajectory_buffer[6]\n",
    "print(data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "p5Co_Fmzj8Sw"
   },
   "source": [
    "We know that step \"6\" is not the start of an episode (the start is step \"3\", since \"3-7\" is the second trajectory we added to the ReplayBuffer), but how do we get the earliest index of that episode?\n",
    "\n",
    "This may seem easy, but it is not. We cannot simply look at the \"done\" flag preceding the start of a new episode: because the third trajectory is not finished yet, step \"3\" is surrounded by \"False\" flags. Things get even trickier with more advanced buffers such as VectorReplayBuffer, which does not store the data in a simple circular queue.\n",
    "\n",
    "Luckily, all ReplayBuffer instances help you identify step indexes through a unified API. You can simply pass an array of indexes and look up the previous index of each one within its episode."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# previous step of indexes [0, 1, 2, 3, 4, 5, 6] are:\n",
    "print(trajectory_buffer.prev(np.array([0, 1, 2, 3, 4, 5, 6])))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4Wlb57V4lQyQ"
   },
   "source": [
    "Using `ReplayBuffer.prev()`, we know that the earliest step of that episode is step \"3\". Similarly, `ReplayBuffer.next()` helps us identify the last index of an episode regardless of which kind of ReplayBuffer we are using."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "zl5TRMo7oOy5",
    "outputId": "4a11612c-3ee0-4e74-b028-c8759e71fbdb"
   },
   "outputs": [],
   "source": [
    "# The next steps of indexes [4, 5, 6, 7, 8, 9] are:\n",
    "print(trajectory_buffer.next(np.array([4, 5, 6, 7, 8, 9])))"
   ]
  },
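  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Building on `prev()`, the start of an episode can be found by stepping backwards until the index stops changing. The sketch below assumes that `prev()` returns an index unchanged when it is already the first step of its episode, as the output above suggests; `episode_start` and `demo_buf` are hypothetical names, not part of Tianshou."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "from tianshou.data import Batch, ReplayBuffer\n",
    "\n",
    "\n",
    "def episode_start(buffer, index):\n",
    "    \"\"\"Hypothetical helper: follow prev() until the index no longer changes.\"\"\"\n",
    "    current = np.array([index])\n",
    "    while True:\n",
    "        previous = buffer.prev(current)\n",
    "        if previous[0] == current[0]:\n",
    "            return int(current[0])\n",
    "        current = previous\n",
    "\n",
    "\n",
    "# Two finished episodes: steps 0-2 and steps 3-7\n",
    "demo_buf = ReplayBuffer(size=10)\n",
    "for i in range(8):\n",
    "    is_last = i in (2, 7)\n",
    "    demo_buf.add(\n",
    "        Batch(obs=i, act=i, rew=i, terminated=is_last, truncated=False, done=is_last, obs_next=i + 1, info={}),\n",
    "    )\n",
    "\n",
    "# Look up the first index of the episode that contains index 6\n",
    "print(episode_start(demo_buf, 6))"
   ]
  },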
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "YJ9CcWZXoOXw"
   },
   "source": [
    "We can also search for the indexes that are labeled \"done: False\" but are the last step of a trajectory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "Xkawk97NpItg",
    "outputId": "df10b359-c2c7-42ca-e50d-9caee6bccadd"
   },
   "outputs": [],
   "source": [
    "print(trajectory_buffer.unfinished_index())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "8_lMr0j3pOmn"
   },
   "source": [
    "The aforementioned APIs are helpful when computing quantities such as GAE and n-step returns in DRL algorithms ([Example usage in Tianshou](https://github.com/thu-ml/tianshou/blob/6fc68578127387522424460790cbcb32a2bd43c4/tianshou/policy/base.py#L384)). The unified APIs ensure a modular design and a flexible interface."
   ]
  },
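  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To give a flavour of such computations, the sketch below sums up to `n` discounted rewards from each step without crossing an episode boundary. It is a simplified illustration rather than Tianshou's implementation: it works on plain arrays stored in chronological order, omits the bootstrapped value term of a full n-step return, and `n_step_rewards` is a hypothetical helper name."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def n_step_rewards(rew, done, gamma, n):\n",
    "    \"\"\"Discounted sum of up to n rewards from each step, stopping at an episode end.\"\"\"\n",
    "    returns = np.zeros(len(rew))\n",
    "    for start in range(len(rew)):\n",
    "        discount = 1.0\n",
    "        for t in range(start, min(start + n, len(rew))):\n",
    "            returns[start] += discount * rew[t]\n",
    "            discount *= gamma\n",
    "            if done[t]:\n",
    "                break  # never accumulate rewards from the next episode\n",
    "    return returns\n",
    "\n",
    "\n",
    "rew = np.array([1.0, 1.0, 1.0, 1.0, 1.0])\n",
    "done = np.array([False, False, True, False, False])  # the first episode ends at index 2\n",
    "print(n_step_rewards(rew, done, gamma=0.5, n=2))"
   ]
  },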
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "FEyE0c7tNfwa"
   },
   "source": [
    "## Further Reading\n",
    "### Other Buffer Modules\n",
    "\n",
    "*   PrioritizedReplayBuffer, which helps you implement [prioritized experience replay](https://arxiv.org/abs/1511.05952)\n",
    "*   CachedReplayBuffer, one main buffer with several cached buffers (higher sample efficiency in some scenarios)\n",
    "*   ReplayBufferManager, a base class that can be inherited (may help you manage multiple buffers)\n",
    "\n",
    "Refer to the documentation and source code for further details.\n",
    "\n",
    "### Support for step stacking to use RNNs in DRL\n",
    "There is an option called `stack_num` (defaults to 1) when initializing the ReplayBuffer, which may help you use RNNs in your algorithm. Check the documentation for details."
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
